{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Gaussian Processes with Doubly Stochastic VI\n",
    "\n",
    "In this notebook, we provide a GPyTorch implementation of deep Gaussian processes, where training and inference is performed using the method of Salimbeni et al., 2017 (https://arxiv.org/abs/1705.08933) adapted to CG-based inference.\n",
    "\n",
    "We'll be training a simple two layer deep GP on the `elevators` UCI dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=0\n"
     ]
    }
   ],
   "source": [
    "%set_env CUDA_VISIBLE_DEVICES=0\n",
    "\n",
    "import torch\n",
    "import gpytorch\n",
    "from torch.nn import Linear\n",
    "from gpytorch.means import ConstantMean\n",
    "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
    "from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution\n",
    "from gpytorch.distributions import MultivariateNormal\n",
    "from gpytorch.models import AbstractVariationalGP, GP\n",
    "from gpytorch.mlls import VariationalELBO, AddedLossTerm\n",
    "from gpytorch.likelihoods import GaussianLikelihood\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models.deep_gps import AbstractDeepGPLayer, AbstractDeepGP, DeepLikelihood "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Data\n",
    "\n",
    "For this example notebook, we'll be using the `elevators` UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. For this notebook, we'll simply be splitting the data using the first 80% of the data as training and the last 20% as testing.\n",
    "\n",
    "**Note**: Running the next cell will attempt to download a ~400 KB dataset file to the current directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os.path\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "import numpy as np\n",
    "\n",
    "if not os.path.isfile('elevators.mat'):\n",
    "    print('Downloading \\'elevators\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', 'elevators.mat')\n",
    "    \n",
    "data = torch.Tensor(loadmat('elevators.mat')['data'])\n",
    "X = data[:, :-1]\n",
    "y = data[:, -1]\n",
    "\n",
    "N = data.shape[0]\n",
    "np.random.seed(0)\n",
    "data = data[np.random.permutation(np.arange(N)),:]\n",
    "\n",
    "train_n = int(floor(0.8*len(X)))\n",
    "\n",
    "train_x = X[:train_n, :].contiguous().cuda()\n",
    "train_y = y[:train_n].contiguous().cuda()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous().cuda()\n",
    "test_y = y[train_n:].contiguous().cuda()\n",
    "\n",
    "mean = train_x.mean(dim=-2, keepdim=True)\n",
    "std = train_x.std(dim=-2, keepdim=True) + 1e-6\n",
    "train_x = (train_x - mean) / std\n",
    "test_x = (test_x - mean) / std\n",
    "\n",
    "mean,std = train_y.mean(),train_y.std()\n",
    "train_y = (train_y - mean) / std\n",
    "test_y = (test_y - mean) / std"
   ]
  },
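  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A random permutation only changes the train/test split if it is applied before the features and targets are sliced out of `data`; permuting `data` after `X` and `y` have been extracted leaves them in their original order. The sketch below shows the shuffle, split, and standardization steps on a small hypothetical tensor (not the `elevators` data), using `torch.randperm` in place of NumPy's permutation; it runs on the CPU.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical stand-in for the loaded data: 10 rows, last column is the target\n",
    "data = torch.arange(30, dtype=torch.float32).reshape(10, 3)\n",
    "\n",
    "# Shuffle the rows first, then slice out features and targets\n",
    "data = data[torch.randperm(data.shape[0])]\n",
    "X, y = data[:, :-1], data[:, -1]\n",
    "\n",
    "# First 80% of the shuffled rows for training, the rest for testing\n",
    "train_n = int(0.8 * len(X))\n",
    "train_x, test_x = X[:train_n], X[train_n:]\n",
    "train_y, test_y = y[:train_n], y[train_n:]\n",
    "\n",
    "# Standardize both splits with statistics computed on the training split only\n",
    "mean = train_x.mean(dim=-2, keepdim=True)\n",
    "std = train_x.std(dim=-2, keepdim=True) + 1e-6\n",
    "train_x = (train_x - mean) / std\n",
    "test_x = (test_x - mean) / std\n",
    "```\n",
    "\n",
    "Because each row is permuted as a whole, every target stays paired with its own features."
   ]
  },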
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining GP layers\n",
    "\n",
    "In GPyTorch, defining a GP involves extending one of our abstract GP models and defining a `forward` method that returns the prior. For deep GPs, things are similar, but there are two abstract GP models that must be overwritten: one for hidden layers and one for the deep GP model itself.\n",
    "\n",
    "In the next cell, we define an example deep GP hidden layer. This looks very similar to every other variational GP you might define. However, there are a few key differences:\n",
    "\n",
    "1. Instead of extending `AbstractVariationalGP`, we extend `AbstractDeepGPLayer`.\n",
    "2. `AbstractDeepGPLayers` need a number of input dimensions, a number of output dimensions, and a number of samples. This is kind of like a linear layer in a standard neural network -- `input_dims` defines how many inputs this hidden layer will expect, and `output_dims` defines how many hidden GPs to create outputs for."
   ]
  },
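  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the doubly stochastic scheme of Salimbeni et al., each layer's outputs are sampled with the reparameterization trick (independently at each data point, from the per-point marginals), and those samples become the inputs to the next layer. The plain-PyTorch sketch below illustrates only the resulting tensor shapes; the means and variances are hypothetical random values standing in for a real GP layer's predictive marginals.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_samples, n, output_dims = 10, 5, 3\n",
    "\n",
    "# Hypothetical predictive marginals of a hidden layer with output_dims GPs\n",
    "mean = torch.randn(n, output_dims)\n",
    "var = torch.rand(n, output_dims) + 0.1\n",
    "\n",
    "# Reparameterization: one standard-normal draw per sample, data point and output\n",
    "eps = torch.randn(num_samples, n, output_dims)\n",
    "samples = mean + var.sqrt() * eps  # broadcasts to (num_samples, n, output_dims)\n",
    "\n",
    "# These samples feed the next layer, whose input_dims must equal output_dims\n",
    "print(samples.shape)  # torch.Size([10, 5, 3])\n",
    "```\n",
    "\n",
    "This is why the output layer of a two-layer model is built with `input_dims` equal to the hidden layer's `output_dims`."
   ]
  },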
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyDeepGPHiddenLayer(AbstractDeepGPLayer):\n",
    "    def __init__(self, input_dims, output_dims, num_inducing=512):\n",
    "        if output_dims is None:\n",
    "            inducing_points = torch.randn(num_inducing, input_dims)\n",
    "        else:\n",
    "            inducing_points = torch.randn(output_dims, num_inducing, input_dims)\n",
    "\n",
    "        variational_distribution = CholeskyVariationalDistribution(\n",
    "            num_inducing_points=num_inducing,\n",
    "            batch_shape=torch.Size([output_dims]) if output_dims is not None else torch.Size([])\n",
    "        )\n",
    "\n",
    "        variational_strategy = VariationalStrategy(\n",
    "            self,\n",
    "            inducing_points,\n",
    "            variational_distribution,\n",
    "            learn_inducing_locations=True\n",
    "        )\n",
    "\n",
    "        super(ToyDeepGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)\n",
    "\n",
    "        self.mean_module = ConstantMean(batch_size=output_dims)\n",
    "        self.covar_module = ScaleKernel(\n",
    "            RBFKernel(batch_size=output_dims, ard_num_dims=input_dims),\n",
    "            batch_size=output_dims, ard_num_dims=None\n",
    "        )\n",
    "        \n",
    "        self.linear_layer = Linear(input_dims, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x) # self.linear_layer(x).squeeze(-1)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return MultivariateNormal(mean_x, covar_x)\n",
    "    \n",
    "    def __call__(self, x, *other_inputs, **kwargs):\n",
    "        \"\"\"\n",
    "        Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections\n",
    "        easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first\n",
    "        hidden layer's outputs and the input data to hidden_layer2.\n",
    "        \"\"\"\n",
    "        if len(other_inputs):\n",
    "            if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):\n",
    "                x = x.rsample()\n",
    "\n",
    "            processed_inputs = [\n",
    "                inp.unsqueeze(0).expand(self.num_samples, *inp.shape)\n",
    "                for inp in other_inputs\n",
    "            ]\n",
    "\n",
    "            x = torch.cat([x] + processed_inputs, dim=-1)\n",
    "\n",
    "        return super().__call__(x, are_samples=bool(len(other_inputs)))\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the model\n",
    "\n",
    "Now that we've defined a class for our hidden layers and a class for our output layer, we can build our deep GP. To do this, we create a `Module` whose forward is simply responsible for forwarding through the various layers.\n",
    "\n",
    "This also allows for various network connectivities easily. For example calling,\n",
    "```\n",
    "hidden_rep2 = self.second_hidden_layer(hidden_rep1, inputs)\n",
    "```\n",
    "in forward would cause the second hidden layer to use both the output of the first hidden layer and the input data as inputs, concatenating the two together."
   ]
  },
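  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The concatenation performed by the `__call__` override in `ToyDeepGPHiddenLayer` can be sketched in plain PyTorch: the raw inputs get a leading sample dimension so that they line up with the samples drawn from the first layer. The sizes below are hypothetical, and random tensors stand in for the first layer's samples.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "num_samples, n, d, hidden_dims = 10, 5, 4, 3\n",
    "\n",
    "inputs = torch.randn(n, d)  # raw data\n",
    "hidden_samples = torch.randn(num_samples, n, hidden_dims)  # stand-in for layer-1 samples\n",
    "\n",
    "# Same expansion as in __call__: add a sample dimension to the raw inputs\n",
    "expanded = inputs.unsqueeze(0).expand(num_samples, *inputs.shape)\n",
    "layer2_inputs = torch.cat([hidden_samples, expanded], dim=-1)\n",
    "\n",
    "print(layer2_inputs.shape)  # torch.Size([10, 5, 7])\n",
    "```\n",
    "\n",
    "With such a skip connection, the second layer would be constructed with `input_dims = hidden_dims + d` (7 in this toy example)."
   ]
  },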
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeepGP(AbstractDeepGP):\n",
    "    def __init__(self, train_x_shape):\n",
    "        hidden_layer = ToyDeepGPHiddenLayer(\n",
    "            input_dims=train_x_shape[-1],\n",
    "            output_dims=10\n",
    "        )\n",
    "        \n",
    "        last_layer = ToyDeepGPHiddenLayer(\n",
    "            input_dims=hidden_layer.output_dims,\n",
    "            output_dims=None,\n",
    "        )\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.hidden_layer = hidden_layer\n",
    "        self.last_layer = last_layer\n",
    "        self.likelihood = DeepLikelihood(GaussianLikelihood())\n",
    "    \n",
    "    def forward(self, inputs):\n",
    "        hidden_rep1 = self.hidden_layer(inputs)\n",
    "        output = self.last_layer(hidden_rep1)\n",
    "        return output\n",
    "    \n",
    "    def predict(self, x):\n",
    "        with gpytorch.settings.fast_computations(log_prob=False, solves=False), torch.no_grad():\n",
    "            preds = self(x)\n",
    "        predictive_means = preds.mean\n",
    "        predictive_variances = preds.variance\n",
    "        \n",
    "        return predictive_means, predictive_variances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DeepGP(train_x.shape).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Likelihood\n",
    "\n",
    "Because deep GPs use some amounts of internal sampling (even in the stochastic variational setting), we need to handle the likelihood in a slightly different way. In the future, we anticipate `DeepLikelihood` being a general wrapper around an arbitrary likelihood once likelihoods become a little more general purpose, but for now we simply define a `DeepGaussianLikelihood` to use for regression."
   ]
  },
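  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the likelihood concrete, the sketch below scores regression labels against a batch of per-sample Gaussians in plain PyTorch. It is not `DeepLikelihood`'s implementation, which we don't reproduce here. It assumes a Gaussian likelihood with a hypothetical fixed noise variance, uses the closed-form expected log-likelihood -log(2*pi*noise)/2 - ((y - m)^2 + v) / (2*noise) for every sample and data point, and then averages over the samples.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_samples, n = 10, 6\n",
    "noise = 0.1  # hypothetical Gaussian noise variance\n",
    "\n",
    "# Per-sample latent means and variances at each point, shape (num_samples, n)\n",
    "f_mean = torch.randn(num_samples, n)\n",
    "f_var = torch.rand(num_samples, n) + 0.05\n",
    "y = torch.randn(n)  # labels of shape (n,) broadcast against the sample dimension\n",
    "\n",
    "# Closed-form E_{f ~ N(m, v)}[log N(y | f, noise)] for every sample and point\n",
    "exp_log_lik = (-0.5 * math.log(2 * math.pi * noise)\n",
    "               - 0.5 * ((y - f_mean) ** 2 + f_var) / noise)\n",
    "\n",
    "# Average over samples, then sum over data points\n",
    "mc_estimate = exp_log_lik.mean(0).sum()\n",
    "```"
   ]
  },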
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the model\n",
    "\n",
    "The training loop for a deep GP looks similar to a standard GP model with stochastic variational inference, but there are a few differences:\n",
    "\n",
    "1. Because the output of a deep GP is actually num_outputs x num_samples Gaussians rather than a single Gaussian, we need to expand the labels to be num_outputs x num_samples x minibatch_size before calling the ELBO.\n",
    "2. Because deep GPs involve a few added loss terms and normalize slightly differently, we created the `VariationalELBO` above with `combine_terms=False`. This just lets us do the extra normalization we need to make the math work out."
   ]
  },
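  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the scaling by the number of training points, the sketch below computes the usual minibatch ELBO estimate in plain PyTorch: the summed expected log-likelihood of the minibatch is rescaled by `num_data / batch_size`, the KL divergences of every layer's inducing variables are subtracted once, and the result is divided by `num_data` to keep the loss on a per-data-point scale. The function `minibatch_elbo` and its inputs are hypothetical; this is not claimed to be line-for-line what `VariationalELBO` computes.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def minibatch_elbo(exp_log_lik_batch, kl_terms, num_data):\n",
    "    # exp_log_lik_batch: expected log-likelihood of each minibatch point,\n",
    "    #   already averaged over samples, shape (batch_size,)\n",
    "    # kl_terms: one KL(q(u) || p(u)) value per layer (hypothetical inputs)\n",
    "    # num_data: total number of training points, e.g. train_x.shape[-2]\n",
    "    batch_size = exp_log_lik_batch.shape[0]\n",
    "    # Unbiased estimate of the full-data sum from a single minibatch\n",
    "    data_term = exp_log_lik_batch.sum() * (num_data / batch_size)\n",
    "    # Per-data-point ELBO: (data term - total KL) / num_data\n",
    "    return (data_term - sum(kl_terms)) / num_data\n",
    "\n",
    "\n",
    "elbo = minibatch_elbo(torch.tensor([-1.0, -2.0]), [torch.tensor(3.0), torch.tensor(1.0)], num_data=10)\n",
    "loss = -elbo  # the quantity the optimizer minimizes, as in the loop below\n",
    "```"
   ]
  },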
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 [0/13] - Loss: 1.954 - - Time: 0.450\n",
      "Epoch 1 [1/13] - Loss: 1.952 - - Time: 0.111\n",
      "Epoch 1 [2/13] - Loss: 1.935 - - Time: 0.101\n",
      "Epoch 1 [3/13] - Loss: 1.922 - - Time: 0.097\n",
      "Epoch 1 [4/13] - Loss: 1.914 - - Time: 0.107\n",
      "Epoch 1 [5/13] - Loss: 1.880 - - Time: 0.103\n",
      "Epoch 1 [6/13] - Loss: 1.875 - - Time: 0.118\n",
      "Epoch 1 [7/13] - Loss: 1.870 - - Time: 0.105\n",
      "Epoch 1 [8/13] - Loss: 1.957 - - Time: 0.110\n",
      "Epoch 1 [9/13] - Loss: 1.874 - - Time: 0.112\n",
      "Epoch 1 [10/13] - Loss: 1.934 - - Time: 0.101\n",
      "Epoch 1 [11/13] - Loss: 1.869 - - Time: 0.099\n",
      "Epoch 1 [12/13] - Loss: 1.853 - - Time: 0.102\n",
      "Epoch 2 [0/13] - Loss: 1.848 - - Time: 0.102\n",
      "Epoch 2 [1/13] - Loss: 1.818 - - Time: 0.101\n",
      "Epoch 2 [2/13] - Loss: 1.882 - - Time: 0.111\n",
      "Epoch 2 [3/13] - Loss: 1.828 - - Time: 0.114\n",
      "Epoch 2 [4/13] - Loss: 1.806 - - Time: 0.123\n",
      "Epoch 2 [5/13] - Loss: 1.730 - - Time: 0.103\n",
      "Epoch 2 [6/13] - Loss: 1.792 - - Time: 0.105\n",
      "Epoch 2 [7/13] - Loss: 1.822 - - Time: 0.107\n",
      "Epoch 2 [8/13] - Loss: 1.859 - - Time: 0.100\n",
      "Epoch 2 [9/13] - Loss: 1.756 - - Time: 0.100\n",
      "Epoch 2 [10/13] - Loss: 1.765 - - Time: 0.114\n",
      "Epoch 2 [11/13] - Loss: 1.720 - - Time: 0.099\n",
      "Epoch 2 [12/13] - Loss: 1.890 - - Time: 0.103\n",
      "Epoch 3 [0/13] - Loss: 1.729 - - Time: 0.104\n",
      "Epoch 3 [1/13] - Loss: 1.812 - - Time: 0.096\n",
      "Epoch 3 [2/13] - Loss: 1.808 - - Time: 0.100\n",
      "Epoch 3 [3/13] - Loss: 1.689 - - Time: 0.103\n",
      "Epoch 3 [4/13] - Loss: 1.682 - - Time: 0.105\n",
      "Epoch 3 [5/13] - Loss: 1.725 - - Time: 0.096\n",
      "Epoch 3 [6/13] - Loss: 1.710 - - Time: 0.104\n",
      "Epoch 3 [7/13] - Loss: 1.697 - - Time: 0.110\n",
      "Epoch 3 [8/13] - Loss: 1.744 - - Time: 0.108\n",
      "Epoch 3 [9/13] - Loss: 1.701 - - Time: 0.107\n",
      "Epoch 3 [10/13] - Loss: 1.709 - - Time: 0.099\n",
      "Epoch 3 [11/13] - Loss: 1.666 - - Time: 0.100\n",
      "Epoch 3 [12/13] - Loss: 1.710 - - Time: 0.101\n",
      "Epoch 4 [0/13] - Loss: 1.679 - - Time: 0.100\n",
      "Epoch 4 [1/13] - Loss: 1.677 - - Time: 0.100\n",
      "Epoch 4 [2/13] - Loss: 1.648 - - Time: 0.101\n",
      "Epoch 4 [3/13] - Loss: 1.677 - - Time: 0.099\n",
      "Epoch 4 [4/13] - Loss: 1.716 - - Time: 0.101\n",
      "Epoch 4 [5/13] - Loss: 1.644 - - Time: 0.097\n",
      "Epoch 4 [6/13] - Loss: 1.634 - - Time: 0.112\n",
      "Epoch 4 [7/13] - Loss: 1.617 - - Time: 0.099\n",
      "Epoch 4 [8/13] - Loss: 1.649 - - Time: 0.106\n",
      "Epoch 4 [9/13] - Loss: 1.601 - - Time: 0.103\n",
      "Epoch 4 [10/13] - Loss: 1.594 - - Time: 0.099\n",
      "Epoch 4 [11/13] - Loss: 1.627 - - Time: 0.096\n",
      "Epoch 4 [12/13] - Loss: 1.574 - - Time: 0.101\n",
      "Epoch 5 [0/13] - Loss: 1.681 - - Time: 0.103\n",
      "Epoch 5 [1/13] - Loss: 1.610 - - Time: 0.100\n",
      "Epoch 5 [2/13] - Loss: 1.599 - - Time: 0.102\n",
      "Epoch 5 [3/13] - Loss: 1.583 - - Time: 0.101\n",
      "Epoch 5 [4/13] - Loss: 1.540 - - Time: 0.101\n",
      "Epoch 5 [5/13] - Loss: 1.529 - - Time: 0.100\n",
      "Epoch 5 [6/13] - Loss: 1.611 - - Time: 0.102\n",
      "Epoch 5 [7/13] - Loss: 1.622 - - Time: 0.102\n",
      "Epoch 5 [8/13] - Loss: 1.522 - - Time: 0.107\n",
      "Epoch 5 [9/13] - Loss: 1.524 - - Time: 0.110\n",
      "Epoch 5 [10/13] - Loss: 1.545 - - Time: 0.107\n",
      "Epoch 5 [11/13] - Loss: 1.539 - - Time: 0.108\n",
      "Epoch 5 [12/13] - Loss: 1.559 - - Time: 0.100\n",
      "Epoch 6 [0/13] - Loss: 1.588 - - Time: 0.096\n",
      "Epoch 6 [1/13] - Loss: 1.560 - - Time: 0.102\n",
      "Epoch 6 [2/13] - Loss: 1.514 - - Time: 0.103\n",
      "Epoch 6 [3/13] - Loss: 1.546 - - Time: 0.100\n",
      "Epoch 6 [4/13] - Loss: 1.576 - - Time: 0.104\n",
      "Epoch 6 [5/13] - Loss: 1.514 - - Time: 0.096\n",
      "Epoch 6 [6/13] - Loss: 1.524 - - Time: 0.105\n",
      "Epoch 6 [7/13] - Loss: 1.492 - - Time: 0.102\n",
      "Epoch 6 [8/13] - Loss: 1.531 - - Time: 0.101\n",
      "Epoch 6 [9/13] - Loss: 1.498 - - Time: 0.102\n",
      "Epoch 6 [10/13] - Loss: 1.539 - - Time: 0.109\n",
      "Epoch 6 [11/13] - Loss: 1.441 - - Time: 0.099\n",
      "Epoch 6 [12/13] - Loss: 1.544 - - Time: 0.098\n",
      "Epoch 7 [0/13] - Loss: 1.520 - - Time: 0.102\n",
      "Epoch 7 [1/13] - Loss: 1.504 - - Time: 0.097\n",
      "Epoch 7 [2/13] - Loss: 1.557 - - Time: 0.099\n",
      "Epoch 7 [3/13] - Loss: 1.532 - - Time: 0.101\n",
      "Epoch 7 [4/13] - Loss: 1.474 - - Time: 0.095\n",
      "Epoch 7 [5/13] - Loss: 1.536 - - Time: 0.100\n",
      "Epoch 7 [6/13] - Loss: 1.514 - - Time: 0.100\n",
      "Epoch 7 [7/13] - Loss: 1.487 - - Time: 0.105\n",
      "Epoch 7 [8/13] - Loss: 1.506 - - Time: 0.103\n",
      "Epoch 7 [9/13] - Loss: 1.465 - - Time: 0.105\n",
      "Epoch 7 [10/13] - Loss: 1.477 - - Time: 0.104\n",
      "Epoch 7 [11/13] - Loss: 1.479 - - Time: 0.102\n",
      "Epoch 7 [12/13] - Loss: 1.450 - - Time: 0.101\n",
      "Epoch 8 [0/13] - Loss: 1.481 - - Time: 0.107\n",
      "Epoch 8 [1/13] - Loss: 1.453 - - Time: 0.102\n",
      "Epoch 8 [2/13] - Loss: 1.466 - - Time: 0.101\n",
      "Epoch 8 [3/13] - Loss: 1.467 - - Time: 0.097\n",
      "Epoch 8 [4/13] - Loss: 1.504 - - Time: 0.099\n",
      "Epoch 8 [5/13] - Loss: 1.502 - - Time: 0.097\n",
      "Epoch 8 [6/13] - Loss: 1.486 - - Time: 0.100\n",
      "Epoch 8 [7/13] - Loss: 1.477 - - Time: 0.098\n",
      "Epoch 8 [8/13] - Loss: 1.481 - - Time: 0.110\n",
      "Epoch 8 [9/13] - Loss: 1.464 - - Time: 0.099\n",
      "Epoch 8 [10/13] - Loss: 1.477 - - Time: 0.099\n",
      "Epoch 8 [11/13] - Loss: 1.504 - - Time: 0.104\n",
      "Epoch 8 [12/13] - Loss: 1.521 - - Time: 0.098\n",
      "Epoch 9 [0/13] - Loss: 1.475 - - Time: 0.103\n",
      "Epoch 9 [1/13] - Loss: 1.525 - - Time: 0.102\n",
      "Epoch 9 [2/13] - Loss: 1.494 - - Time: 0.098\n",
      "Epoch 9 [3/13] - Loss: 1.451 - - Time: 0.101\n",
      "Epoch 9 [4/13] - Loss: 1.506 - - Time: 0.104\n",
      "Epoch 9 [5/13] - Loss: 1.463 - - Time: 0.098\n",
      "Epoch 9 [6/13] - Loss: 1.482 - - Time: 0.103\n",
      "Epoch 9 [7/13] - Loss: 1.438 - - Time: 0.106\n",
      "Epoch 9 [8/13] - Loss: 1.430 - - Time: 0.110\n",
      "Epoch 9 [9/13] - Loss: 1.475 - - Time: 0.107\n",
      "Epoch 9 [10/13] - Loss: 1.448 - - Time: 0.108\n",
      "Epoch 9 [11/13] - Loss: 1.456 - - Time: 0.099\n",
      "Epoch 9 [12/13] - Loss: 1.500 - - Time: 0.093\n",
      "Epoch 10 [0/13] - Loss: 1.488 - - Time: 0.102\n",
      "Epoch 10 [1/13] - Loss: 1.529 - - Time: 0.097\n",
      "Epoch 10 [2/13] - Loss: 1.459 - - Time: 0.102\n",
      "Epoch 10 [3/13] - Loss: 1.454 - - Time: 0.099\n",
      "Epoch 10 [4/13] - Loss: 1.485 - - Time: 0.102\n",
      "Epoch 10 [5/13] - Loss: 1.412 - - Time: 0.101\n",
      "Epoch 10 [6/13] - Loss: 1.438 - - Time: 0.105\n",
      "Epoch 10 [7/13] - Loss: 1.489 - - Time: 0.098\n",
      "Epoch 10 [8/13] - Loss: 1.473 - - Time: 0.103\n",
      "Epoch 10 [9/13] - Loss: 1.449 - - Time: 0.103\n",
      "Epoch 10 [10/13] - Loss: 1.439 - - Time: 0.105\n",
      "Epoch 10 [11/13] - Loss: 1.466 - - Time: 0.102\n",
      "Epoch 10 [12/13] - Loss: 1.465 - - Time: 0.105\n",
      "Epoch 11 [0/13] - Loss: 1.436 - - Time: 0.100\n",
      "Epoch 11 [1/13] - Loss: 1.418 - - Time: 0.099\n",
      "Epoch 11 [2/13] - Loss: 1.461 - - Time: 0.109\n",
      "Epoch 11 [3/13] - Loss: 1.502 - - Time: 0.104\n",
      "Epoch 11 [4/13] - Loss: 1.437 - - Time: 0.103\n",
      "Epoch 11 [5/13] - Loss: 1.466 - - Time: 0.102\n",
      "Epoch 11 [6/13] - Loss: 1.435 - - Time: 0.101\n",
      "Epoch 11 [7/13] - Loss: 1.460 - - Time: 0.104\n",
      "Epoch 11 [8/13] - Loss: 1.397 - - Time: 0.103\n",
      "Epoch 11 [9/13] - Loss: 1.492 - - Time: 0.102\n",
      "Epoch 11 [10/13] - Loss: 1.475 - - Time: 0.107\n",
      "Epoch 11 [11/13] - Loss: 1.464 - - Time: 0.102\n",
      "Epoch 11 [12/13] - Loss: 1.536 - - Time: 0.103\n",
      "Epoch 12 [0/13] - Loss: 1.465 - - Time: 0.103\n",
      "Epoch 12 [1/13] - Loss: 1.505 - - Time: 0.102\n",
      "Epoch 12 [2/13] - Loss: 1.453 - - Time: 0.099\n",
      "Epoch 12 [3/13] - Loss: 1.438 - - Time: 0.100\n",
      "Epoch 12 [4/13] - Loss: 1.451 - - Time: 0.102\n",
      "Epoch 12 [5/13] - Loss: 1.481 - - Time: 0.098\n",
      "Epoch 12 [6/13] - Loss: 1.438 - - Time: 0.098\n",
      "Epoch 12 [7/13] - Loss: 1.424 - - Time: 0.101\n",
      "Epoch 12 [8/13] - Loss: 1.459 - - Time: 0.097\n",
      "Epoch 12 [9/13] - Loss: 1.484 - - Time: 0.103\n",
      "Epoch 12 [10/13] - Loss: 1.442 - - Time: 0.100\n",
      "Epoch 12 [11/13] - Loss: 1.449 - - Time: 0.105\n",
      "Epoch 12 [12/13] - Loss: 1.435 - - Time: 0.098\n",
      "Epoch 13 [0/13] - Loss: 1.452 - - Time: 0.103\n",
      "Epoch 13 [1/13] - Loss: 1.429 - - Time: 0.098\n",
      "Epoch 13 [2/13] - Loss: 1.437 - - Time: 0.100\n",
      "Epoch 13 [3/13] - Loss: 1.432 - - Time: 0.098\n",
      "Epoch 13 [4/13] - Loss: 1.418 - - Time: 0.098\n",
      "Epoch 13 [5/13] - Loss: 1.467 - - Time: 0.104\n",
      "Epoch 13 [6/13] - Loss: 1.469 - - Time: 0.102\n",
      "Epoch 13 [7/13] - Loss: 1.463 - - Time: 0.102\n",
      "Epoch 13 [8/13] - Loss: 1.479 - - Time: 0.103\n",
      "Epoch 13 [9/13] - Loss: 1.462 - - Time: 0.108\n",
      "Epoch 13 [10/13] - Loss: 1.503 - - Time: 0.106\n",
      "Epoch 13 [11/13] - Loss: 1.417 - - Time: 0.104\n",
      "Epoch 13 [12/13] - Loss: 1.456 - - Time: 0.099\n",
      "Epoch 14 [0/13] - Loss: 1.448 - - Time: 0.104\n",
      "Epoch 14 [1/13] - Loss: 1.463 - - Time: 0.101\n",
      "Epoch 14 [2/13] - Loss: 1.368 - - Time: 0.104\n",
      "Epoch 14 [3/13] - Loss: 1.441 - - Time: 0.103\n",
      "Epoch 14 [4/13] - Loss: 1.465 - - Time: 0.100\n",
      "Epoch 14 [5/13] - Loss: 1.478 - - Time: 0.100\n",
      "Epoch 14 [6/13] - Loss: 1.421 - - Time: 0.101\n",
      "Epoch 14 [7/13] - Loss: 1.438 - - Time: 0.106\n",
      "Epoch 14 [8/13] - Loss: 1.493 - - Time: 0.103\n",
      "Epoch 14 [9/13] - Loss: 1.464 - - Time: 0.100\n",
      "Epoch 14 [10/13] - Loss: 1.474 - - Time: 0.099\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 [11/13] - Loss: 1.433 - - Time: 0.098\n",
      "Epoch 14 [12/13] - Loss: 1.458 - - Time: 0.100\n",
      "Epoch 15 [0/13] - Loss: 1.406 - - Time: 0.101\n",
      "Epoch 15 [1/13] - Loss: 1.477 - - Time: 0.106\n",
      "Epoch 15 [2/13] - Loss: 1.458 - - Time: 0.101\n",
      "Epoch 15 [3/13] - Loss: 1.442 - - Time: 0.103\n",
      "Epoch 15 [4/13] - Loss: 1.463 - - Time: 0.108\n",
      "Epoch 15 [5/13] - Loss: 1.464 - - Time: 0.098\n",
      "Epoch 15 [6/13] - Loss: 1.435 - - Time: 0.102\n",
      "Epoch 15 [7/13] - Loss: 1.426 - - Time: 0.104\n",
      "Epoch 15 [8/13] - Loss: 1.466 - - Time: 0.102\n",
      "Epoch 15 [9/13] - Loss: 1.404 - - Time: 0.110\n",
      "Epoch 15 [10/13] - Loss: 1.458 - - Time: 0.106\n",
      "Epoch 15 [11/13] - Loss: 1.425 - - Time: 0.105\n",
      "Epoch 15 [12/13] - Loss: 1.492 - - Time: 0.105\n",
      "Epoch 16 [0/13] - Loss: 1.511 - - Time: 0.102\n",
      "Epoch 16 [1/13] - Loss: 1.406 - - Time: 0.100\n",
      "Epoch 16 [2/13] - Loss: 1.452 - - Time: 0.097\n",
      "Epoch 16 [3/13] - Loss: 1.431 - - Time: 0.102\n",
      "Epoch 16 [4/13] - Loss: 1.413 - - Time: 0.104\n",
      "Epoch 16 [5/13] - Loss: 1.428 - - Time: 0.106\n",
      "Epoch 16 [6/13] - Loss: 1.468 - - Time: 0.104\n",
      "Epoch 16 [7/13] - Loss: 1.485 - - Time: 0.100\n",
      "Epoch 16 [8/13] - Loss: 1.408 - - Time: 0.100\n",
      "Epoch 16 [9/13] - Loss: 1.418 - - Time: 0.100\n",
      "Epoch 16 [10/13] - Loss: 1.437 - - Time: 0.104\n",
      "Epoch 16 [11/13] - Loss: 1.473 - - Time: 0.102\n",
      "Epoch 16 [12/13] - Loss: 1.460 - - Time: 0.099\n",
      "Epoch 17 [0/13] - Loss: 1.436 - - Time: 0.105\n",
      "Epoch 17 [1/13] - Loss: 1.432 - - Time: 0.101\n",
      "Epoch 17 [2/13] - Loss: 1.504 - - Time: 0.103\n",
      "Epoch 17 [3/13] - Loss: 1.418 - - Time: 0.103\n",
      "Epoch 17 [4/13] - Loss: 1.405 - - Time: 0.102\n",
      "Epoch 17 [5/13] - Loss: 1.460 - - Time: 0.101\n",
      "Epoch 17 [6/13] - Loss: 1.476 - - Time: 0.099\n",
      "Epoch 17 [7/13] - Loss: 1.493 - - Time: 0.105\n",
      "Epoch 17 [8/13] - Loss: 1.422 - - Time: 0.106\n",
      "Epoch 17 [9/13] - Loss: 1.458 - - Time: 0.107\n",
      "Epoch 17 [10/13] - Loss: 1.441 - - Time: 0.108\n",
      "Epoch 17 [11/13] - Loss: 1.449 - - Time: 0.110\n",
      "Epoch 17 [12/13] - Loss: 1.373 - - Time: 0.094\n",
      "Epoch 18 [0/13] - Loss: 1.412 - - Time: 0.099\n",
      "Epoch 18 [1/13] - Loss: 1.477 - - Time: 0.105\n",
      "Epoch 18 [2/13] - Loss: 1.407 - - Time: 0.107\n",
      "Epoch 18 [3/13] - Loss: 1.432 - - Time: 0.103\n",
      "Epoch 18 [4/13] - Loss: 1.456 - - Time: 0.102\n",
      "Epoch 18 [5/13] - Loss: 1.470 - - Time: 0.103\n",
      "Epoch 18 [6/13] - Loss: 1.446 - - Time: 0.100\n",
      "Epoch 18 [7/13] - Loss: 1.450 - - Time: 0.107\n",
      "Epoch 18 [8/13] - Loss: 1.522 - - Time: 0.100\n",
      "Epoch 18 [9/13] - Loss: 1.449 - - Time: 0.099\n",
      "Epoch 18 [10/13] - Loss: 1.411 - - Time: 0.103\n",
      "Epoch 18 [11/13] - Loss: 1.425 - - Time: 0.103\n",
      "Epoch 18 [12/13] - Loss: 1.391 - - Time: 0.099\n",
      "Epoch 19 [0/13] - Loss: 1.419 - - Time: 0.105\n",
      "Epoch 19 [1/13] - Loss: 1.475 - - Time: 0.104\n",
      "Epoch 19 [2/13] - Loss: 1.470 - - Time: 0.098\n",
      "Epoch 19 [3/13] - Loss: 1.429 - - Time: 0.105\n",
      "Epoch 19 [4/13] - Loss: 1.421 - - Time: 0.101\n",
      "Epoch 19 [5/13] - Loss: 1.462 - - Time: 0.099\n",
      "Epoch 19 [6/13] - Loss: 1.407 - - Time: 0.098\n",
      "Epoch 19 [7/13] - Loss: 1.434 - - Time: 0.104\n",
      "Epoch 19 [8/13] - Loss: 1.459 - - Time: 0.105\n",
      "Epoch 19 [9/13] - Loss: 1.415 - - Time: 0.106\n",
      "Epoch 19 [10/13] - Loss: 1.431 - - Time: 0.103\n",
      "Epoch 19 [11/13] - Loss: 1.462 - - Time: 0.104\n",
      "Epoch 19 [12/13] - Loss: 1.440 - - Time: 0.102\n",
      "Epoch 20 [0/13] - Loss: 1.469 - - Time: 0.099\n",
      "Epoch 20 [1/13] - Loss: 1.363 - - Time: 0.098\n",
      "Epoch 20 [2/13] - Loss: 1.409 - - Time: 0.116\n",
      "Epoch 20 [3/13] - Loss: 1.477 - - Time: 0.105\n",
      "Epoch 20 [4/13] - Loss: 1.473 - - Time: 0.104\n",
      "Epoch 20 [5/13] - Loss: 1.428 - - Time: 0.104\n",
      "Epoch 20 [6/13] - Loss: 1.477 - - Time: 0.105\n",
      "Epoch 20 [7/13] - Loss: 1.391 - - Time: 0.107\n",
      "Epoch 20 [8/13] - Loss: 1.453 - - Time: 0.112\n",
      "Epoch 20 [9/13] - Loss: 1.381 - - Time: 0.108\n",
      "Epoch 20 [10/13] - Loss: 1.357 - - Time: 0.107\n",
      "Epoch 20 [11/13] - Loss: 1.370 - - Time: 0.102\n",
      "Epoch 20 [12/13] - Loss: 1.334 - - Time: 0.107\n",
      "Epoch 21 [0/13] - Loss: 1.357 - - Time: 0.111\n",
      "Epoch 21 [1/13] - Loss: 1.349 - - Time: 0.099\n",
      "Epoch 21 [2/13] - Loss: 1.304 - - Time: 0.123\n",
      "Epoch 21 [3/13] - Loss: 1.322 - - Time: 0.097\n",
      "Epoch 21 [4/13] - Loss: 1.325 - - Time: 0.108\n",
      "Epoch 21 [5/13] - Loss: 1.283 - - Time: 0.105\n",
      "Epoch 21 [6/13] - Loss: 1.302 - - Time: 0.116\n",
      "Epoch 21 [7/13] - Loss: 1.238 - - Time: 0.104\n",
      "Epoch 21 [8/13] - Loss: 1.275 - - Time: 0.105\n",
      "Epoch 21 [9/13] - Loss: 1.258 - - Time: 0.100\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 60\n",
    "\n",
    "optimizer = torch.optim.Adam([\n",
    "    {'params': model.parameters()},\n",
    "], lr=0.01)\n",
    "mll = VariationalELBO(model.likelihood, model, train_x.shape[-2])\n",
    "\n",
    "import time\n",
    "\n",
    "with gpytorch.settings.fast_computations(log_prob=False, solves=False):\n",
    "    for i in range(num_epochs):\n",
    "        for minibatch_i, (x_batch, y_batch) in enumerate(train_loader):\n",
    "            start_time = time.time()\n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            output = model(x_batch)\n",
    "            loss = -mll(output, y_batch)\n",
    "            print('Epoch %d [%d/%d] - Loss: %.3f - - Time: %.3f' % (i + 1, minibatch_i, len(train_loader), loss.item(), time.time() - start_time))\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make predictions and get an RMSE\n",
    "\n",
    "The output distribution of a deep GP in this framework is actually a mixture of `num_samples` Gaussians for each output. We get predictions the same way with all GPyTorch models, but we do currently need to do some reshaping to get the means and variances in a reasonable form.\n",
    "\n",
    "SVGP gets an RMSE of around 0.41 after 60 epochs of training, so overall getting an RMSE of 0.35 out of a 2 layer deep GP without much tuning involved is pretty good!"
   ]
  },
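  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The RMSE cell only needs the mixture mean, which it obtains by averaging `predictive_means` over its first (sample) dimension. If you also want a single predictive variance per test point, the moments of an equally weighted Gaussian mixture follow from the law of total variance. A minimal sketch, assuming (as the `.mean(0)` in the RMSE cell suggests) that `predictive_means` and `predictive_variances` have shape `(num_samples, n)`; `mixture_moments` and the toy tensors are hypothetical:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def mixture_moments(sample_means, sample_vars):\n",
    "    # sample_means, sample_vars: shape (num_samples, n); each row is one\n",
    "    # equally weighted Gaussian component of the predictive mixture\n",
    "    mean = sample_means.mean(0)\n",
    "    # Law of total variance: average within-sample variance plus the\n",
    "    # spread of the sample means around the mixture mean\n",
    "    var = sample_vars.mean(0) + ((sample_means - mean) ** 2).mean(0)\n",
    "    return mean, var\n",
    "\n",
    "\n",
    "means = torch.tensor([[0.0, 1.0], [2.0, 1.0]])\n",
    "variances = torch.tensor([[1.0, 0.5], [1.0, 0.5]])\n",
    "mix_mean, mix_var = mixture_moments(means, variances)\n",
    "# mix_mean is [1, 1] and mix_var is [2, 0.5]\n",
    "```"
   ]
  },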
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.3552, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "import gpytorch\n",
    "model.eval()\n",
    "predictive_means, predictive_variances = model.predict(test_x)\n",
    "\n",
    "rmse = torch.mean(torch.pow(predictive_means.mean(0) - test_y, 2)).sqrt()\n",
    "print(rmse)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
